import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F

# Module-wide torch device used by the tensor helpers below: CUDA when
# torch reports it as available, otherwise the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

class PIDController:
    """PID-style controller whose integral term is a windowed mean of errors.

    The "I" contribution is ``Ki`` times the mean of a fixed 24-slot buffer
    holding the most recent errors (every slot starts at 0.5).  The running
    sum ``integral`` is maintained on each update but does not feed into the
    value returned by :meth:`update`.
    """

    def __init__(self, Kp, Ki, Kd, target):
        """Store the three gains and the set point; reset controller state."""
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd
        self.target = target
        # Sliding window of the latest 24 errors, oldest first.
        self.cumulate = np.full(24, 0.5)
        self.integral = 0
        self.prev_error = 0

    def update(self, current_value):
        """Feed one measurement and return the control output.

        Args:
            current_value: Latest measured value of the controlled quantity.

        Returns:
            ``Kp * error + Ki * mean(window) + Kd * (error - previous error)``
            where ``error = target - current_value``.
        """
        err = self.target - current_value
        self.integral += err

        # Shift the window left by one slot in place and append the new error.
        window = self.cumulate
        window[:-1] = window[1:]
        window[-1] = err

        p_term = self.Kp * err
        i_term = self.Ki * np.mean(window)
        d_term = self.Kd * (err - self.prev_error)

        self.prev_error = err
        return p_term + i_term + d_term

class Generator:
    """Draws random NumPy sample arrays whose sizes come from ``args``.

    ``args`` must expose ``num_sample_train`` (number of rows / instances)
    and ``num_agent`` (number of columns).  All draws use the global
    ``np.random`` state.
    """

    def __init__(self, args):
        """Keep a reference to the configuration object ``args``."""
        self.args = args

    def generate_uniform(self, low, high) -> np.ndarray:
        """Return a ``(num_sample_train, num_agent)`` array drawn from U(low, high)."""
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        return np.random.uniform(low, high, [num_instances, num_agent])

    def generate_ctr(self, low, high) -> np.ndarray:
        """Return a 1-D array of ``num_sample_train`` draws from U(low, high)."""
        num_instances = self.args.num_sample_train
        return np.random.uniform(low, high, [num_instances])

    def generate_uniform2(self, low, high) -> np.ndarray:
        """Like :meth:`generate_uniform`, but column 0 is redrawn from U(0, 1)."""
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        sample_val = np.random.uniform(low, high, [num_instances, num_agent])
        sample_val[:, 0] = np.random.uniform(0, 1, [num_instances])
        return sample_val

    def generate_asymmetry(self) -> np.ndarray:
        """Return samples whose trailing half comes from a wider distribution.

        The first ``num_sample_train - num_sample_train // 2`` rows are drawn
        from U(0, 1); the remaining ``num_sample_train // 2`` rows are drawn
        from U(0, 2).
        """
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        sample_val = np.random.uniform(0, 1, [num_instances, num_agent])

        half = num_instances // 2
        # Index the trailing half from the front.  A negative slice ending at
        # -1 would skip the last row and, for even counts, select one row
        # fewer than the replacement block (shape-mismatch ValueError).  The
        # front-based start also stays correct when ``half`` is 0.
        sample_val[num_instances - half:, :] = np.random.uniform(0, 2, [half, num_agent])
        return sample_val

def get_position_ctr(position, choice):
    """Map slot indices to click-through-rate factors (NumPy version).

    Args:
        position: NumPy array of slot indices.
        choice: Scheme selector.  Values 1 and 2 currently use the same
            table; any other value leaves the result all zeros.

    Returns:
        Float array shaped like ``position`` holding 1.0 / 0.8 / 0.6 / 0.5
        for slots 0 / 1 / 2 / 3 and 0.0 everywhere else.
    """
    ctr = np.zeros(position.shape)
    if choice == 2 or choice == 1:
        # Table index is the slot number, value is its CTR factor.
        for slot, rate in enumerate((1.0, 0.8, 0.6, 0.5)):
            ctr[position == slot] = rate
    return ctr

def torch_get_position_ctr(position, choice):
    """Map slot indices to click-through-rate factors (torch version).

    Args:
        position: Tensor of slot indices.
        choice: Scheme selector.  Values 1 and 2 currently use the same
            table; any other value leaves the result all zeros.

    Returns:
        float32 tensor on the module-level ``device``, shaped like
        ``position``, holding 1.0 / 0.8 / 0.6 / 0.5 for slots 0 / 1 / 2 / 3
        and 0.0 everywhere else.
    """
    ctr = t.zeros(position.shape).float().to(device)
    if choice == 2 or choice == 1:
        # Table index is the slot number, value is its CTR factor.
        for slot, rate in enumerate((1.0, 0.8, 0.6, 0.5)):
            ctr[position == slot] = rate
    return ctr

def deterministic_NeuralSort(s, tau):
    """Return a row-wise softmax relaxation of a sorting permutation.

    Args:
        s: Score tensor of shape ``(batch, n, 1)``.
        tau: Softmax temperature; scores are divided by it before softmax.

    Returns:
        float32 tensor of shape ``(batch, n, n)`` on the module-level
        ``device``; every row sums to 1.  Presumably this is the NeuralSort
        operator, where row ``i`` softly selects the ``i``-th largest score
        -- confirm against callers.
    """
    n = s.shape[1]
    ones_col = t.ones((n, 1), dtype=t.float32).to(device)

    # |s_j - s_k| for every pair, then summed over k and repeated across
    # the n output columns via the all-ones (n, n) matrix.
    pairwise_abs = t.abs(s - s.permute(0, 2, 1)).float()
    abs_row_sums = t.matmul(pairwise_abs, t.matmul(ones_col, ones_col.t()))

    # Per-rank coefficient n + 1 - 2 * i for i = 1..n.
    rank_coeff = (n + 1 - 2 * t.arange(1, n + 1)).type(t.float32).to(device)
    scaled_scores = t.matmul(s.float(), rank_coeff.unsqueeze(0))

    logits = (scaled_scores - abs_row_sums).permute(0, 2, 1)
    return t.softmax(logits / tau, dim=-1)

def cumulative_average(arr):
    """Return the running mean of ``arr`` along its first axis.

    Entry ``k`` of the result is the mean of every element in ``arr[:k + 1]``
    (for multi-dimensional input, all elements of the first ``k + 1`` rows).

    Args:
        arr: Sequence or NumPy array of numbers.

    Returns:
        1-D float64 array of length ``len(arr)``; empty for empty input.
    """
    values = np.asarray(arr)
    n = len(values)
    if n == 0:
        return np.zeros(0)

    # One cumulative sum replaces re-averaging every prefix (O(n) vs O(n^2)).
    rows = values.reshape(n, -1)
    row_sums = rows.sum(axis=1, dtype=np.float64)
    element_counts = np.arange(1, n + 1) * rows.shape[1]
    return np.cumsum(row_sums) / element_counts

def calculate_t(t):
    """Return a noisy piecewise-linear value for time step ``t``.

    ``t`` is reduced modulo 24.  The noise-free profile rises from 1.5 to
    2.0 over 0..5, falls from 2.0 to 1.0 over 5..17, and rises again from
    1.0 towards 1.5 after 17.  Gaussian noise with standard deviation 0.1
    is added and the result is floored at 0.5.

    Note: the parameter name shadows the module-level ``t`` (torch) alias
    inside this function; torch is not used here.
    """
    phase = t % 24
    if phase <= 5:
        base = 1.5 + 0.5 * (phase / 5)
    elif phase <= 17:
        base = 2.0 - (phase - 5) / 12
    else:
        base = 1.0 + 0.5 * (phase - 17) / 6
    # Exactly one normal draw per call, taken after the base is computed.
    return max(base + 0.1 * np.random.normal(), 0.5)

class MLP(nn.Module):
    """Fully connected network with a configurable output squashing.

    Args:
        layers: Sequence of layer widths ``[in, hidden..., out]``; one
            ``nn.Linear`` is created per consecutive pair.
        activation: Callable applied after every layer except the last.
        output_type: ``'tanh'`` scales the last layer to ``5 * tanh(.)``,
            ``'sigmoid'`` to ``5 * sigmoid(.)``.  Any other value returns the
            last layer's raw linear output.
    """

    def __init__(self, layers, activation, output_type='tanh'):
        super().__init__()
        self.layers_list = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )
        self.activation = activation
        self.output_type = output_type

    def forward(self, x):
        """Run ``x`` through every layer and return the network output."""
        last = len(self.layers_list) - 1
        for j, layer in enumerate(self.layers_list):
            x = layer(x)
            if j < last:
                x = self.activation(x)
            elif self.output_type == 'tanh':
                x = 5 * t.tanh(x)
            elif self.output_type == 'sigmoid':
                x = 5 * t.sigmoid(x)
            # Any other output_type keeps the raw linear output.  The final
            # layer is always applied, so the output width is layers[-1].
        return x

def linear_interpolation2(x, A):
    """Piecewise-linear interpolation of ``x[0]`` through the knots in ``A``.

    Queries outside the knot range are extrapolated along the first or last
    segment rather than clamped.

    Args:
        x: Indexable whose first element is the query x-coordinate.
        A: ``(n, 2)`` NumPy array of ``(x, y)`` knots with ``n >= 2``; rows
            may be in any order.

    Returns:
        Tuple ``(x_coord, y_coord)``.  If the selected segment has zero
        width (two end knots sharing one x value), the first knot's ``y`` is
        returned instead of dividing by zero.
    """
    # Stable sort keeps the input order of knots that share an x value.
    knots = A[A[:, 0].argsort(kind="stable")]
    xs = knots[:, 0]
    x_coord = x[0]

    if x_coord <= xs[0]:
        (x1, y1), (x2, y2) = knots[0], knots[1]
    elif x_coord >= xs[-1]:
        (x1, y1), (x2, y2) = knots[-2], knots[-1]
    else:
        # First segment with xs[i] < x_coord <= xs[i + 1]; this is the same
        # segment a left-to-right scan would pick.
        i = int(np.searchsorted(xs, x_coord, side="left")) - 1
        (x1, y1), (x2, y2) = knots[i], knots[i + 1]

    if x2 == x1:
        # Degenerate (vertical) segment: there is no slope to follow.
        return x_coord, y1

    y_coord = y1 + (y2 - y1) * (x_coord - x1) / (x2 - x1)
    return x_coord, y_coord

def linear_interpolation(x, A):
    """Evaluate the least-squares line through the points in ``A`` at ``x[0]``.

    Despite the name, this fits a single straight line ``y = m * x + c`` to
    all points with ``np.linalg.lstsq``; it does not interpolate between
    neighbouring points.

    Args:
        x: Indexable whose first element is the query x-coordinate.
        A: 2-D NumPy array whose columns 0 and 1 hold the x and y samples.

    Returns:
        Tuple ``(x_new, y_new)`` with ``y_new`` the fitted line's value.
    """
    xs, ys = A[:, 0], A[:, 1]

    # Design matrix [x, 1] so the solution vector is (slope, intercept).
    design = np.vstack([xs, np.ones(len(xs))]).T
    slope, intercept = np.linalg.lstsq(design, ys, rcond=None)[0]

    x_new = x[0]
    return x_new, slope * x_new + intercept

def calculate_distances(x, y, A, B, C, D):
    """Measure how far points lie from the normalised corner ``(1, 1)``.

    All inputs are stacked, and each of the first two coordinates is divided
    by its maximum over the stacked points.  Distances are Euclidean in that
    normalised space.

    Args:
        x, y: Single 2-D points (each contributes one stacked row).
        A, B, C, D: Collections of 2-D points, one point per row.

    Returns:
        Tuple ``(dist_x, dist_y, min_dist_A, min_dist_B, min_dist_C,
        min_dist_D)``: the distances of ``x`` and ``y`` to the corner,
        followed by the smallest distance within each of the four groups.
    """
    stacked = np.vstack((x, y, A, B, C, D))

    # Scale each coordinate by its maximum over all stacked points.
    scaled = np.column_stack((
        stacked[:, 0] / np.max(stacked[:, 0]),
        stacked[:, 1] / np.max(stacked[:, 1]),
    ))
    corner = np.array([1, 1])

    dist_x = np.linalg.norm(corner - scaled[0])
    dist_y = np.linalg.norm(corner - scaled[1])

    # Row ranges of the four groups; rows 0 and 1 belong to x and y, and
    # the last group runs to the end of the stack.
    start_a = 2
    start_b = start_a + len(A)
    start_c = start_b + len(B)
    start_d = start_c + len(C)
    spans = ((start_a, start_b), (start_b, start_c), (start_c, start_d), (start_d, None))

    nearest = [
        np.min(np.linalg.norm(scaled[lo:hi] - corner, axis=1))
        for lo, hi in spans
    ]

    return dist_x, dist_y, nearest[0], nearest[1], nearest[2], nearest[3]

